Add support to cuDNN CTC loss #32302
Conversation
@pkanwar23, could you please help find someone to review this?
@chsigg, could you take a look at this? Thanks!
std::vector<int> *labels_lengths) {
  const T* h_in = labels_indices->flat<T>().data();
  for (int i = 0; i < num_indices; i++) {
    T key = h_in[i * 2];
const ref
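The C++ loop above derives per-batch label lengths from the sparse labels_indices tensor. A minimal Python sketch of the same counting logic (the function name and plain-list representation are illustrative, not part of the PR):

```python
def labels_lengths_from_indices(labels_indices, batch_size):
    """Count labels per batch entry from sparse (batch, time) index pairs.

    Mirrors the C++ loop above: the first element of each index pair is
    the batch id (the `key` in the snippet), and a batch entry's label
    length is the number of index pairs carrying that id.
    """
    lengths = [0] * batch_size
    for batch_id, _time in labels_indices:
        lengths[batch_id] += 1
    return lengths
```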
// takes the ownership of the underlying memory. The expectation is that the
// memory should be alive for the span of the cudnnCTCLoss itself.
template <typename T>
class CudnnCtcLossAllocatorInTemp : public ScratchAllocator {
This code is identical to e.g. CudnnBatchNormAllocatorInTemp in fused_batch_norm_op.cc. Can you consolidate them to a single location instead of duplicating code, please?
@@ -62,6 +62,43 @@ REGISTER_OP("CTCLoss")
      return Status::OK();
    });

REGISTER_OP("CTCLossV2")
    .Input("inputs: float")
Does the CuDNN implementation support types other than float? If so, we should also support them here.
#31164 added support for double for CTCLossOp, for example.
No, cuDNN only supports float for CTCLoss.
Adding Tim Shen to review the stream executor bits.
@sanjoy @alextp, I have replaced the previous environment variable with the implementation selector (thanks @qlzh727 for helping me with some test cases). Now we don't need the env var to control whether cuDNN is used; the runtime can automatically determine whether a GPU is available. I added another Python function, ctc_loss_v3, to contain this new implementation; it is only available in TF2. Please help me find reviewers for this part. Thanks.
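The implementation-selector approach described here can be sketched as a simple runtime dispatch; the names below are illustrative stand-ins and do not reflect TensorFlow's actual selector API:

```python
def select_ctc_loss_impl(gpu_available, cudnn_impl, generic_impl):
    """Pick the CTC loss implementation at runtime.

    Replaces the old TF_CUDNN_CTC_LOSS environment variable: prefer the
    cuDNN-backed kernel when a GPU is available, otherwise fall back to
    the generic implementation.
    """
    return cudnn_impl if gpu_available else generic_impl

# Hypothetical stand-ins for the two kernels.
def cudnn_ctc_loss(logits):
    return ("cudnn", logits)

def generic_ctc_loss(logits):
    return ("generic", logits)
```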
Looks good! Just a couple of minor nits and we can approve.
Thanks!
tensorflow/python/ops/ctc_ops.py
Outdated
@@ -42,6 +46,28 @@
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export

import os
The linter will complain; standard Python imports need to come before all others.
Sure. Removed this unused import.
@@ -0,0 +1,71 @@
op {
  graph_op_name: "CTCLossV2"
Can you add a line here saying "visibility: HIDDEN"? This will prevent the generation of a tf.ctc_loss_v2 API.
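For reference, a hedged sketch of what the resulting api_def entry would look like, combining the snippet above with the requested line (the exact file contents may differ):

```
op {
  graph_op_name: "CTCLossV2"
  visibility: HIDDEN
}
```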
Sure. Done.
@@ -576,6 +576,10 @@ tf_module {
  name: "cosh"
Now you can revert this file to make the API tests pass again
@@ -1068,6 +1068,10 @@ tf_module {
  name: "cross"
Now you can revert this file to make the API tests pass again
Done. PTAL.
Anything else I can do? Thanks.
Thanks for checking. We should be good. I'm looking at why it hasn't merged.
It was waiting for an approval from me for some reason. Should be good to go now.
Thanks for the update.
PiperOrigin-RevId: 290387603 Change-Id: I28491f42a4559a9f79bd6a7b73d8e6b670f55368
This PR adds cuDNN CTC loss as a backend for ctc_loss_v2().
Users need to set the environment variable TF_CUDNN_CTC_LOSS=1.
Can we make it the default (without TF_CUDNN_CTC_LOSS)?
What is the logic in the new ctc_loss_v2()?
fyi @nluehr